{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Introduction\n",
    "\n",
    "This tutorial will guide you through estimating normal vectors and principal curvatures from 3D point clouds using classic PCA and Jet fitting methods as well as our recent DeepFit method. \n",
    "\n",
    "\n",
    "Surface normals and curvatures are important properties in shape analysis and are widely used in fields such as computer graphics and computer vision. \n",
    "\n",
    "\n",
    "### The math\n",
    "3D point clouds are represented as a matrix of $(x, y, z)$ coordinates of points in 3D space. \n",
    "If you are unfamiliar with plane fitting and jets, this section provides a short background. For more in-depth information see the references at the bottom. \n",
    "\n",
    "All of the methods below share the same pipeline: \n",
    "* Input: 3D point cloud + query point $q_i$\n",
    "* Find the k nearest neighbors of $q_i$.  \n",
    "* Do something fancy :)\n",
    "* Output: Normal vector at the query point, $N_{q_i}$\n",
    "\n",
    "To estimate the normal at every point, we use each point in the point cloud as the query point in turn. \n",
    "\n",
    "\n",
    "#### PCA\n",
    "In this method we estimate the tangent plane to the underlying surface at the query point. This boils down to computing the eigendecomposition of the covariance matrix built from the query point's nearest neighbors:\n",
    "\n",
    "\\begin{equation}\n",
    "C = \\frac{1}{k}\\sum_{i=1}^{k}(p_i-\\hat{p})(p_i-\\hat{p})^T\n",
    "\\end{equation}\n",
    "\n",
    "Here $p_i$ are the neighboring points and $\\hat{p}$ is their centroid. \n",
    "The normal vector is the eigenvector associated with the smallest eigenvalue. \n",
    "\n",
    "Using this method we cannot directly estimate the principal curvatures (the principal curvatures of a plane are 0).\n",
    "\n",
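    "As a short worked step (a sketch using labels that do not appear in the text above: $\\lambda_j$ and $v_j$ are our own names for the eigenvalues and unit eigenvectors of $C$), the decomposition can be written as:\n",
    "\n",
    "\\begin{equation}\n",
    "C v_j = \\lambda_j v_j, \\quad j=1,2,3, \\qquad \\lambda_1 \\geq \\lambda_2 \\geq \\lambda_3 \\geq 0\n",
    "\\end{equation}\n",
    "\n",
    "Under this labeling the estimated normal is $N_{q_i} = \\pm v_3$: the neighbors vary least along $v_3$, while $v_1$ and $v_2$ span the fitted tangent plane. The sign is undetermined, which is why the estimated normals are unoriented. \n",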
    "\n",
    "#### Jet fitting\n",
    "\n",
    " An $n$-jet of the height function over a surface is given by:\n",
    "\n",
    "\\begin{equation}\n",
    "    f(x,y)=J_{\\beta,n}(x,y)= \\sum_{k=0}^{n}\\sum_{j=0}^{k}\\beta_{k-j,j}x^{k-j}y^j\n",
    "\\end{equation}\n",
    "\n",
    "Here $\\beta$ is the jet coefficients vector that consists of $N_n=(n+1)(n+2)/2$ terms.\n",
    "\n",
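    "As a concrete instance (a worked expansion of the formula above, not an additional result), for $n=2$ the jet has $N_2=(2+1)(2+2)/2=6$ coefficients:\n",
    "\n",
    "\\begin{equation}\n",
    "    J_{\\beta,2}(x,y) = \\beta_{0,0} + \\beta_{1,0}x + \\beta_{0,1}y + \\beta_{2,0}x^2 + \\beta_{1,1}xy + \\beta_{0,2}y^2\n",
    "\\end{equation}\n",
    "\n",
    "By the same count, an order-3 jet has $N_3=(3+1)(3+2)/2=10$ coefficients.\n",
    "\n",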
    "We require that every point satisfy the equation above, yielding the system of linear equations:\n",
    "\\begin{equation}\n",
    "    M\\beta = B\n",
    "\\end{equation}\n",
    "\n",
    "The system is typically overdetermined (more points than coefficients), so it is solved in the least-squares sense; the solution has the well-known closed form: \n",
    "\\begin{equation}\n",
    "    \\beta = (M^TM)^{-1}M^TB\n",
    "\\end{equation}\n",
    "\n",
    "Here $M=(1, x_i, y_i, ..., x_i y_i^{n-1}, y_i^n)_{i=1,...,N_p}\\in \\mathbb{R}^{N_p \\times N_n}$ is the Vandermonde matrix and $B=(z_1, z_2,...z_{N_p})^T \\in \\mathbb{R}^{N_p}$ is the height function vector, where $N_p$ is the number of sampled points. Both represent the sampled points.\n"
   ]
  },
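  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The fitted coefficients determine the differential quantities at the query point. The following is a sketch based on the standard differential geometry of a height function $z=f(x,y)$, evaluated at the origin of the local frame. It is our own summary rather than a description of the exact code path, and the sign and ordering conventions used by `DeepFit.fit_Wjet` and `tu.compute_principal_curvatures` may differ. The normal follows from the first-order coefficients:\n",
    "\n",
    "\\begin{equation}\n",
    "    N = \\frac{(-\\beta_{1,0}, -\\beta_{0,1}, 1)}{\\sqrt{1+\\beta_{1,0}^2+\\beta_{0,1}^2}}\n",
    "\\end{equation}\n",
    "\n",
    "For $n \\geq 2$, the principal curvatures are the eigenvalues of the shape operator $S=\\mathrm{I}^{-1}\\mathrm{II}$, where the first and second fundamental forms at the origin are:\n",
    "\n",
    "\\begin{equation}\n",
    "    \\mathrm{I} = \\begin{pmatrix} 1+\\beta_{1,0}^2 & \\beta_{1,0}\\beta_{0,1} \\\\ \\beta_{1,0}\\beta_{0,1} & 1+\\beta_{0,1}^2 \\end{pmatrix}, \\qquad \\mathrm{II} = \\frac{1}{\\sqrt{1+\\beta_{1,0}^2+\\beta_{0,1}^2}} \\begin{pmatrix} 2\\beta_{2,0} & \\beta_{1,1} \\\\ \\beta_{1,1} & 2\\beta_{0,2} \\end{pmatrix}\n",
    "\\end{equation}\n",
    "\n",
    "With a first-order fit ($n=1$) the second-order coefficients are absent, which is consistent with a plane fit having zero principal curvatures."
   ]
  },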
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### The code"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "pycharm": {
     "is_executing": false,
     "name": "#%% Imports\n"
    },
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import sys  \n",
    "import os\n",
    "import numpy as np\n",
    "sys.path.insert(0, '../utils')\n",
    "sys.path.insert(0, '../models')\n",
    "sys.path.insert(0, '../trained_models')\n",
    "import DeepFit\n",
    "import tutorial_utils as tu\n",
    "import torch\n",
    "import ipyvolume as ipv\n",
    "import ipywidgets as widgets\n",
    "from IPython.display import display\n",
    "import functools\n",
    "\n",
    "gpu_idx = 0\n",
    "device = torch.device(\"cpu\" if gpu_idx < 0 else \"cuda:%d\" % gpu_idx)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Generating data\n",
    "3D point cloud data can be obtained from 3D sensors like LiDAR or RGBD cameras. For this tutorial we will generate a synthetic example. This gives us a ground-truth parametric surface as a reference for evaluating our results. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "jet_order_data = 3\n",
    "n_points = 4096\n",
    "point_cloud_dataset = tu.SyntheticPointCloudDataset(n_points, jet_order_data, points_per_patch=128)\n",
    "dataloader = torch.utils.data.DataLoader(point_cloud_dataset, batch_size=256, num_workers=8)\n",
    "\n",
    "points = point_cloud_dataset.points\n",
    "gt_normals = point_cloud_dataset.gt_normals"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Normal estimation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "jet_order_fit = 3\n",
    "for batchind, data in enumerate(dataloader):\n",
    "    points, scale_radius = data[0], data[1]\n",
    "    points = points.to(device)\n",
    "    scale_radius = scale_radius.to(device)\n",
    "    \n",
    "    # compute jet coefficients and normal estimation\n",
    "    pca_beta, n_est, _ = DeepFit.fit_Wjet(points, torch.ones_like(points[:, 0]), order=1,\n",
    "                               compute_neighbor_normals=False)\n",
    "    pca_normals = n_est if batchind==0 else torch.cat([pca_normals, n_est], 0)\n",
    "    \n",
    "    jet_beta, n_est, _ = DeepFit.fit_Wjet(points, torch.ones_like(points[:, 0]), order=jet_order_fit,\n",
    "                               compute_neighbor_normals=False)\n",
    "    jet_normals = n_est if batchind==0 else torch.cat([jet_normals, n_est], 0)\n",
    "    \n",
    "    # compute principal curvatures\n",
    "    jet_curv_est, principal_dirs = tu.compute_principal_curvatures(jet_beta)\n",
    "    jet_curv_est = jet_curv_est / scale_radius.unsqueeze(-1).repeat(1, jet_curv_est.shape[1])\n",
    "    jet_curv_est = jet_curv_est.detach().cpu()\n",
    "    jet_curvatures = jet_curv_est if batchind==0 else torch.cat([jet_curvatures, jet_curv_est], 0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The normals are unoriented so we will now flip them upwards (assuming that the positive z axis points up)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_sign = torch.sign(torch.sum(pca_normals*\n",
    "                              torch.tensor([0., 0., 1.], device=pca_normals.device).repeat([n_points, 1]), dim=1)).unsqueeze(-1)\n",
    "pca_normals = n_sign * pca_normals\n",
    "pca_normals = pca_normals.detach().cpu().numpy()\n",
    "\n",
    "n_sign = torch.sign(torch.sum(jet_normals*\n",
    "                              torch.tensor([0., 0., 1.], device=jet_normals.device).repeat([n_points, 1]), dim=1)).unsqueeze(-1)\n",
    "jet_normals = n_sign * jet_normals\n",
    "jet_normals = jet_normals.detach().cpu().numpy()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Visualization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "color_n_pca = tu.normal2rgb(pca_normals) #convert normal vectors to RGB\n",
    "color_n_jet = tu.normal2rgb(jet_normals)\n",
    "color_n_gt = tu.normal2rgb(gt_normals)\n",
    "\n",
    "curvature_range_min = [-torch.mean(torch.abs(jet_curvatures[:, 1])) - torch.std(torch.abs(jet_curvatures[:, 1])), \n",
    "                                torch.mean(torch.abs(jet_curvatures[:, 1])) + torch.std(torch.abs(jet_curvatures[:, 1]))]\n",
    "curvature_range_max = [-torch.mean(torch.abs(jet_curvatures[:, 0])) - torch.std(torch.abs(jet_curvatures[:, 0])), \n",
    "                                torch.mean(torch.abs(jet_curvatures[:, 0])) + torch.std(torch.abs(jet_curvatures[:, 0]))]\n",
    "color_curv = tu.curvatures2rgb(jet_curvatures.clone(),  k1_range=curvature_range_max, k2_range=curvature_range_min)\n",
    "\n",
    "# make some buttons to toggle between the colors \n",
    "btn_pca = widgets.Button(description='PCA')\n",
    "btn_pca.style.button_color='lightgray'\n",
    "btn_jet = widgets.Button(description='Jet')\n",
    "btn_jet.style.button_color='lightgray'\n",
    "btn_solid = widgets.Button(description='solid')\n",
    "btn_solid.style.button_color='lightgray'\n",
    "btn_gt = widgets.Button(description='GT')\n",
    "btn_gt.style.button_color='lightgray'\n",
    "btn_curv = widgets.Button(description='curvature')\n",
    "btn_curv.style.button_color='lightgray'\n",
    "\n",
    "def update_pc_normal_color(b, scatter_h):\n",
    "    \"\"\"    \n",
    "        this function is linked to the buttons and updates the point cloud color\n",
    "    \"\"\"\n",
    "    if b.description == 'PCA':\n",
    "        scatter_h.color = color_n_pca\n",
    "    elif b.description == 'Jet': \n",
    "        scatter_h.color = color_n_jet\n",
    "    elif b.description == 'solid':\n",
    "        scatter_h.color = 'red'\n",
    "    elif b.description == 'GT':\n",
    "        scatter_h.color = color_n_gt\n",
    "    elif b.description == 'curvature':\n",
    "        scatter_h.color = color_curv\n",
    "\n",
    "\n",
    "#plot\n",
    "fig_h = ipv.figure()\n",
    "scatter_h = ipv.pylab.scatter(point_cloud_dataset.points[:, 0], \n",
    "                              point_cloud_dataset.points[:, 1], \n",
    "                              point_cloud_dataset.points[:, 2], size=1, marker=\"sphere\", color='red')\n",
    "ipv.pylab.xyzlim(-1, 1)\n",
    "ipv.style.use('minimal')\n",
    "# button function \n",
    "btn_pca.on_click(functools.partial(update_pc_normal_color, scatter_h=scatter_h))\n",
    "btn_jet.on_click(functools.partial(update_pc_normal_color, scatter_h=scatter_h))\n",
    "btn_gt.on_click(functools.partial(update_pc_normal_color, scatter_h=scatter_h))\n",
    "btn_solid.on_click(functools.partial(update_pc_normal_color, scatter_h=scatter_h))\n",
    "btn_curv.on_click(functools.partial(update_pc_normal_color, scatter_h=scatter_h))\n",
    "\n",
    "display(widgets.VBox((widgets.HBox((btn_pca, btn_jet, btn_gt, btn_solid, btn_curv)), fig_h)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Loading the data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "pycharm": {
     "is_executing": false,
     "name": "#%%\n"
    },
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[-0.47420675  0.04765987 -0.40615332]\n",
      " [-0.19475208  0.22290343  0.47209102]\n",
      " [-0.54465634 -0.18347211  0.03932041]\n",
      " ...\n",
      " [-0.13093506  0.3629976  -0.06708436]\n",
      " [ 0.5013776   0.18889578  0.12530105]\n",
      " [-0.12950778  0.3153689  -0.2243162 ]]\n"
     ]
    }
   ],
   "source": [
    "point_cloud_dataset = tu.SinglePointCloudDataset('./Boxy_smooth100k.xyz', points_per_patch=256)\n",
    "print(point_cloud_dataset.points)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Fit an n-jet (and compute normal vector)\n",
    "This may take a while, depending on the number of points. \n",
    "\n",
    "Note: this is the classic jet fitting method which uses non-weighted least squares. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "pycharm": {
     "is_executing": false,
     "name": "#%%\n"
    },
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "jet_order = 3\n",
    "dataloader = torch.utils.data.DataLoader(point_cloud_dataset, batch_size=256, num_workers=8)\n",
    "\n",
    "for batchind, data in enumerate(dataloader, 0):\n",
    "    points = data[0]\n",
    "    data_trans = data[1]\n",
    "    scale_radius = data[-1]\n",
    "    points = points.to(device)\n",
    "    data_trans = data_trans.to(device)\n",
    "    scale_radius = scale_radius.to(device)\n",
    "    \n",
    "    beta, n_est, neighbors_n_est = DeepFit.fit_Wjet(points, torch.ones_like(points[:, 0]), order=jet_order,\n",
    "                               compute_neighbor_normals=False)\n",
    "    n_est = torch.bmm(n_est.unsqueeze(1), data_trans.transpose(2, 1)).squeeze(dim=1) # cancel out pca\n",
    "    n_est = n_est.detach().cpu()\n",
    "    normals = n_est if batchind==0 else torch.cat([normals, n_est], 0)\n",
    "    \n",
    "    curv_est, principal_dirs = tu.compute_principal_curvatures(beta)\n",
    "    curv_est = curv_est / scale_radius.unsqueeze(-1).repeat(1, curv_est.shape[1])\n",
    "    curv_est = curv_est.detach().cpu()\n",
    "    curvatures = curv_est if batchind == 0 else torch.cat([curvatures, curv_est], 0)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The normals are unoriented so we will now flip them outwards (assuming that the origin is an internal point and the shape is simple)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "pycharm": {
     "is_executing": false,
     "name": "#%%\n"
    },
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "n_sign = torch.sign(torch.sum(normals*torch.tensor(point_cloud_dataset.points), dim=1)).unsqueeze(-1)\n",
    "normals = n_sign * normals\n",
    "normals = normals.detach().cpu()\n",
    "curvatures = n_sign.repeat([1, 2]) * curvatures"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Visualize the point cloud\n",
    "We plot the point cloud and allow for 3 color overlays:\n",
    "* Solid - (all points have the same color)\n",
    "* Normals - We map the normal vectors to the RGB cube and use these values for coloring the point cloud.\n",
    "* Curvatures - We map the principal curvatures to RGB values using a colormap. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "pycharm": {
     "is_executing": false,
     "name": "#%% Visualize normals\n"
    },
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6e78bbf3b094449585700e5fb7609dea",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "VBox(children=(HBox(children=(Button(description='Solid', style=ButtonStyle(button_color='lightgray')), Button…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "color_n = tu.normal2rgb(normals.numpy()) #convert normal vectors to RGB\n",
    "curvature_range_min = [-torch.mean(torch.abs(curvatures[:, 1])) - torch.std(torch.abs(curvatures[:, 1])), \n",
    "                                torch.mean(torch.abs(curvatures[:, 1])) + torch.std(torch.abs(curvatures[:, 1]))]\n",
    "curvature_range_max = [-torch.mean(torch.abs(curvatures[:, 0])) - torch.std(torch.abs(curvatures[:, 0])), \n",
    "                                torch.mean(torch.abs(curvatures[:, 0])) + torch.std(torch.abs(curvatures[:, 0]))]\n",
    "color_curv = tu.curvatures2rgb(curvatures.clone(),  k1_range=curvature_range_max, k2_range=curvature_range_min)\n",
    "\n",
    "# make some buttons to toggle between the colors \n",
    "btn_pc = widgets.Button(description='Solid')\n",
    "btn_pc.style.button_color='lightgray'\n",
    "btn_n = widgets.Button(description='normals')\n",
    "btn_n.style.button_color='lightgray'\n",
    "btn_c = widgets.Button(description='Curvatures')\n",
    "btn_c.style.button_color='lightgray'\n",
    "\n",
    "def update_pc_color_to_solid(b, scatter_h):\n",
    "    \"\"\"    \n",
    "        this function is linked to the 'Solid' button and resets the point cloud to a solid color\n",
    "    \"\"\"\n",
    "    scatter_h.color = 'red'\n",
    "\n",
    "    \n",
    "def update_pc_color_to_normal(b, scatter_h):\n",
    "    \"\"\"    \n",
    "        this function is linked to the 'normals' button and colors the point cloud by normal direction\n",
    "    \"\"\"\n",
    "    scatter_h.color = color_n\n",
    "    \n",
    "def update_pc_color_to_curv(b, scatter_h):\n",
    "    \"\"\"    \n",
    "        this function is linked to the 'Curvatures' button and colors the point cloud by curvature\n",
    "    \"\"\"\n",
    "    scatter_h.color = color_curv\n",
    "\n",
    "#plot\n",
    "fig_h = ipv.figure()\n",
    "scatter_h = ipv.pylab.scatter(point_cloud_dataset.points[:, 0], \n",
    "                              point_cloud_dataset.points[:, 1], \n",
    "                              point_cloud_dataset.points[:, 2], size=0.5, marker=\"sphere\", color='red')\n",
    "ipv.pylab.xyzlim(-1, 1)\n",
    "ipv.style.use('minimal')\n",
    "# ipv.show()\n",
    "\n",
    "btn_pc.on_click(functools.partial(update_pc_color_to_solid, scatter_h=scatter_h))\n",
    "btn_n.on_click(functools.partial(update_pc_color_to_normal, scatter_h=scatter_h))\n",
    "btn_c.on_click(functools.partial(update_pc_color_to_curv, scatter_h=scatter_h))\n",
    "display(widgets.VBox((widgets.HBox((btn_pc, btn_n, btn_c)), fig_h)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Fit an n-jet using DeepFit (and compute normal vector and principal curvatures)\n",
    "We first load the pretrained model and its parameters. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using 3 order jet for surface fitting\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "DeepFit(\n",
       "  (feat): PointNetEncoder(\n",
       "    (pointfeat): PointNetFeatures(\n",
       "      (conv1): Conv1d(3, 64, kernel_size=(1,), stride=(1,))\n",
       "      (conv2): Conv1d(64, 64, kernel_size=(1,), stride=(1,))\n",
       "      (bn1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (bn2): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (stn1): QSTN(\n",
       "        (conv1): Conv1d(3, 64, kernel_size=(1,), stride=(1,))\n",
       "        (conv2): Conv1d(64, 128, kernel_size=(1,), stride=(1,))\n",
       "        (conv3): Conv1d(128, 1024, kernel_size=(1,), stride=(1,))\n",
       "        (mp1): MaxPool1d(kernel_size=256, stride=256, padding=0, dilation=1, ceil_mode=False)\n",
       "        (fc1): Linear(in_features=1024, out_features=512, bias=True)\n",
       "        (fc2): Linear(in_features=512, out_features=256, bias=True)\n",
       "        (fc3): Linear(in_features=256, out_features=4, bias=True)\n",
       "        (bn1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (bn2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (bn3): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (bn4): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (bn5): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      )\n",
       "      (stn2): STN(\n",
       "        (conv1): Conv1d(64, 64, kernel_size=(1,), stride=(1,))\n",
       "        (conv2): Conv1d(64, 128, kernel_size=(1,), stride=(1,))\n",
       "        (conv3): Conv1d(128, 1024, kernel_size=(1,), stride=(1,))\n",
       "        (mp1): MaxPool1d(kernel_size=256, stride=256, padding=0, dilation=1, ceil_mode=False)\n",
       "        (fc1): Linear(in_features=1024, out_features=512, bias=True)\n",
       "        (fc2): Linear(in_features=512, out_features=256, bias=True)\n",
       "        (fc3): Linear(in_features=256, out_features=4096, bias=True)\n",
       "        (bn1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (bn2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (bn3): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (bn4): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (bn5): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      )\n",
       "    )\n",
       "    (conv2): Conv1d(64, 128, kernel_size=(1,), stride=(1,))\n",
       "    (conv3): Conv1d(128, 1024, kernel_size=(1,), stride=(1,))\n",
       "    (bn2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (bn3): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  )\n",
       "  (conv1): Conv1d(1088, 512, kernel_size=(1,), stride=(1,))\n",
       "  (conv2): Conv1d(512, 256, kernel_size=(1,), stride=(1,))\n",
       "  (conv3): Conv1d(256, 128, kernel_size=(1,), stride=(1,))\n",
       "  (conv4): Conv1d(128, 1, kernel_size=(1,), stride=(1,))\n",
       "  (bn1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  (bn2): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  (bn3): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  (do): Dropout(p=0.25, inplace=False)\n",
       ")"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "trained_model_path = '../trained_models/DeepFit'\n",
    "params = torch.load(os.path.join(trained_model_path, 'DeepFit_params.pth'))\n",
    "k_neighbors = params.points_per_patch  # note: you can use a different number; this is the value the network was trained on\n",
    "jet_order = params.jet_order\n",
    "print('Using {} order jet for surface fitting'.format(jet_order))\n",
    "model = DeepFit.DeepFit(k=1, num_points=k_neighbors, use_point_stn=params.use_point_stn,\n",
    "                        use_feat_stn=params.use_feat_stn, point_tuple=params.point_tuple, sym_op=params.sym_op,\n",
    "                        arch=params.arch, n_gaussians=params.n_gaussians, jet_order=jet_order,\n",
    "                        weight_mode=params.weight_mode, use_consistency=False)\n",
    "checkpoint = torch.load(os.path.join(trained_model_path, 'DeepFit.pth'), map_location=device)\n",
    "model.load_state_dict(checkpoint)\n",
    "model.to(device)\n",
    "model.eval()  # inference mode: disables dropout and uses the BatchNorm running statistics"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next we use DeepFit to fit a surface at each point in the point cloud. \n",
    "This may take a while, depending on the number of points. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataloader = torch.utils.data.DataLoader(point_cloud_dataset, batch_size=128, num_workers=8)\n",
    "\n",
    "for batchind, data in enumerate(dataloader, 0):\n",
    "    points = data[0]\n",
    "    data_trans = data[1]\n",
    "    scale_radius = data[-1].squeeze()\n",
    "    points = points.to(device)\n",
    "    data_trans = data_trans.to(device)\n",
    "    scale_radius = scale_radius.to(device)\n",
    "    n_est, beta, weights, trans, trans2, neighbors_n_est = model.forward(points)\n",
    "    if params.use_point_stn:\n",
    "        n_est = torch.bmm(n_est.unsqueeze(1), trans.transpose(2, 1)).squeeze(dim=1) # cancel out the PointNet STN\n",
    "    n_est = torch.bmm(n_est.unsqueeze(1), data_trans.transpose(2, 1)).squeeze(dim=1)  # cancel out pca\n",
    "    n_est = n_est.detach().cpu()\n",
    "    normals = n_est if batchind == 0 else torch.cat([normals, n_est], 0)\n",
    "    \n",
    "    curv_est, principal_dirs = tu.compute_principal_curvatures(beta)\n",
    "    curv_est = curv_est / scale_radius.unsqueeze(-1).repeat(1, curv_est.shape[1])\n",
    "    curv_est = curv_est.detach().cpu()\n",
    "    curvatures = curv_est if batchind == 0 else torch.cat([curvatures, curv_est], 0)\n",
    "    "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The normals are unoriented so we will now flip them outwards (assuming that the origin is an internal point and the shape is simple)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_sign = torch.sign(torch.sum(normals*torch.tensor(point_cloud_dataset.points), dim=1)).unsqueeze(-1)\n",
    "normals = n_sign * normals\n",
    "normals = normals.detach().cpu()\n",
    "curvatures = n_sign.repeat([1, 2]) * curvatures"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Visualize the point cloud\n",
    "We plot the point cloud and allow for 3 color overlays:\n",
    "* Solid - (all points have the same color)\n",
    "* Normals - We map the normal vectors to the RGB cube and use these values for coloring the point cloud.\n",
    "* Curvatures - We map the principal curvatures to RGB values using a colormap. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "43dc69aad6aa406891b8b0391d28d7d6",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "VBox(children=(HBox(children=(Button(description='Solid', style=ButtonStyle(button_color='lightgray')), Button…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "color_n = tu.normal2rgb(normals.numpy()) #convert normal vectors to RGB\n",
    "curvature_range_min = [-torch.mean(torch.abs(curvatures[:, 1])) - torch.std(torch.abs(curvatures[:, 1])), \n",
    "                                torch.mean(torch.abs(curvatures[:, 1])) + torch.std(torch.abs(curvatures[:, 1]))]\n",
    "curvature_range_max = [-torch.mean(torch.abs(curvatures[:, 0])) - torch.std(torch.abs(curvatures[:, 0])), \n",
    "                                torch.mean(torch.abs(curvatures[:, 0])) + torch.std(torch.abs(curvatures[:, 0]))]\n",
    "color_curv = tu.curvatures2rgb(curvatures.clone(),  k1_range=curvature_range_max, k2_range=curvature_range_min)\n",
    "\n",
    "# make some buttons to toggle between the colors \n",
    "btn_pc = widgets.Button(description='Solid')\n",
    "btn_pc.style.button_color='lightgray'\n",
    "btn_n = widgets.Button(description='normals')\n",
    "btn_n.style.button_color='lightgray'\n",
    "btn_c = widgets.Button(description='Curvatures')\n",
    "btn_c.style.button_color='lightgray'\n",
    "\n",
    "def update_pc_color_to_solid(b, scatter_h):\n",
    "    \"\"\"    \n",
    "        this function is linked to the 'Solid' button and resets the point cloud to a solid color\n",
    "    \"\"\"\n",
    "    scatter_h.color = 'red'\n",
    "\n",
    "    \n",
    "def update_pc_color_to_normal(b, scatter_h):\n",
    "    \"\"\"    \n",
    "        this function is linked to the 'normals' button and colors the point cloud by normal direction\n",
    "    \"\"\"\n",
    "    scatter_h.color = color_n\n",
    "\n",
    "def update_pc_color_to_curv(b, scatter_h):\n",
    "    \"\"\"    \n",
    "        this function is linked to the 'Curvatures' button and colors the point cloud by curvature\n",
    "    \"\"\"\n",
    "    scatter_h.color = color_curv\n",
    "\n",
    "#plot\n",
    "fig_h = ipv.figure()\n",
    "scatter_h = ipv.pylab.scatter(point_cloud_dataset.points[:, 0], \n",
    "                              point_cloud_dataset.points[:, 1], \n",
    "                              point_cloud_dataset.points[:, 2], size=0.5, marker=\"sphere\", color='red')\n",
    "ipv.pylab.xyzlim(-1, 1)\n",
    "ipv.style.use('minimal')\n",
    "# ipv.show()\n",
    "\n",
    "btn_pc.on_click(functools.partial(update_pc_color_to_solid, scatter_h=scatter_h))\n",
    "btn_n.on_click(functools.partial(update_pc_color_to_normal, scatter_h=scatter_h))\n",
    "btn_c.on_click(functools.partial(update_pc_color_to_curv, scatter_h=scatter_h))\n",
    "display(widgets.VBox((widgets.HBox((btn_pc, btn_n, btn_c)), fig_h)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "scrolled": true
   },
   "source": [
    "### References\n",
    "* [PCL normal estimation using PCA](http://pointclouds.org/documentation/tutorials/normal_estimation.php)\n",
    "* [CGAL normal estimation using Jets](https://doc.cgal.org/latest/Jet_fitting_3/index.html#Jet_fitting_3Mathematical)\n",
    "* [DeepFit paper](https://arxiv.org/pdf/2003.10826.pdf)\n",
    "* [Jet fitting paper](https://graphics.stanford.edu/courses/cs468-03-fall/Papers/cazals_jets.pdf)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
